[MAX] Add UniPC multistep scheduler for Wan diffusion #13
## Summary

Add a numpy-only UniPC multistep scheduler for Wan diffusion pipelines.

## Description

- Implements the UniPC-BH2 algorithm with corrector and predictor steps
- Supports flow-matching sigma schedules (used by Wan 2.1/2.2)
- Provides `build_step_coefficients()` to precompute per-step coefficient matrices on the host, enabling on-device scheduler steps without Python-side numpy calls during denoising
- Registers `UniPCMultistepScheduler` in the diffusion scheduler factory

This is a numpy-only port of the diffusers `UniPCMultistepScheduler`, specialized for the Wan pipeline configuration.

## Dependencies

None. Can be merged independently.

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code

stack-info: PR: #13, branch: jglee-sqbits/stack/1
7b02fbe to aca5cea
Code Review
This pull request introduces the `UniPCMultistepScheduler`, a NumPy-only implementation of the UniPC-BH2 algorithm designed for fast sampling in diffusion models like the Wan 2.2 T2V pipeline. The scheduler includes support for flow-matching and pre-computes step coefficients for optimized inference. Review feedback highlights several issues in the coefficient pre-computation methods: `_predictor_coefficients` and `_corrector_coefficients` hardcode parameters such as `solver_type` and `predict_x0`, and they lack the logic needed to handle solver orders greater than two.
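For context on the flow-matching support mentioned above, here is a minimal sketch of how a flow-matching sigma could map to the log-SNR value `lambda` and the step size `h` that the coefficient code relies on. It assumes the convention `alpha = 1 - sigma` and `lambda = log(alpha) - log(sigma)`; the function names are hypothetical and are not taken from this PR.

```python
import numpy as np


def alpha_sigma_from_flow_sigma(sigma: float) -> tuple[float, float]:
    """Assumed flow-matching convention: alpha = 1 - sigma, sigma unchanged."""
    return 1.0 - sigma, sigma


def lambda_from_sigma(sigma: float) -> float:
    """Log-SNR for one sigma; this sketch only covers 0 < sigma < 1."""
    alpha_t, sigma_t = alpha_sigma_from_flow_sigma(sigma)
    return float(np.log(alpha_t) - np.log(sigma_t))


def step_size(sigma_s0: float, sigma_t: float) -> float:
    """h = lambda_t - lambda_s0 for a step from sigma_s0 to sigma_t."""
    return lambda_from_sigma(sigma_t) - lambda_from_sigma(sigma_s0)
```

Under these assumptions `h` is positive whenever sigma decreases over the step, which is the usual direction during denoising.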

    b_h = float(np.expm1(-h))
    sample_scale = float(sigma_t / sigma_s0)
    m0_scale = float(-alpha_t * b_h)
    m1_scale = 0.0

The `_predictor_coefficients` method hardcodes the calculation for `solver_type="bh2"` and assumes `predict_x0=True`. It should respect `self.solver_type` and `self.predict_x0` to stay consistent with the `step()` method. Specifically, `hh` should be determined by `predict_x0`, and the scale factor for `m0` should be `alpha_t` or `sigma_t` depending on the prediction type.

    hh = -h
    h_phi_1 = float(np.expm1(hh))
    b_h = float(np.expm1(hh))

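To make the point about `predict_x0` and `solver_type` concrete, the following is a rough sketch of first-order predictor scales that branch on both settings. It mirrors my understanding of the diffusers UniPC update rather than the code in this PR; the function name and argument list are hypothetical.

```python
import math


def first_order_predictor_scales(
    alpha_t: float,
    sigma_t: float,
    alpha_s0: float,
    sigma_s0: float,
    h: float,
    predict_x0: bool = True,
    solver_type: str = "bh2",
) -> tuple[float, float, float]:
    """Return (sample_scale, m0_scale, b_h) for a first-order predictor step."""
    # The sign of the exponent depends on the prediction type.
    hh = -h if predict_x0 else h
    h_phi_1 = math.expm1(hh)

    # B(h) only matters for higher-order terms, but it is chosen here.
    if solver_type == "bh1":
        b_h = hh
    elif solver_type == "bh2":
        b_h = math.expm1(hh)
    else:
        raise NotImplementedError(f"unsupported solver_type: {solver_type}")

    if predict_x0:
        sample_scale = sigma_t / sigma_s0
        m0_scale = -alpha_t * h_phi_1
    else:
        sample_scale = alpha_t / alpha_s0
        m0_scale = -sigma_t * h_phi_1
    return sample_scale, m0_scale, b_h
```

With `predict_x0=True` and `solver_type="bh2"` this reduces to the hardcoded expressions quoted above, since `h_phi_1` and `b_h` coincide in that case.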

    if order == 2:
        sigma_si_raw = float(self.sigmas[step_index - 1])
        lambda_si = self._lambda_from_sigma(sigma_si_raw)
        rk = (lambda_si - lambda_s0) / h
        m1_scale = float(-alpha_t * b_h * 0.5 / rk)
        m0_scale -= m1_scale

`_predictor_coefficients` only supports orders up to 2. If `solver_order` is set to a higher value, this method returns incorrect coefficients, because the predictor effectively falls back to order 1. It should be generalized to handle arbitrary orders using a linear system solver, similar to the implementation in `multistep_uni_p_bh_update`.
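A generalized version might look roughly like the sketch below. It follows the structure of the diffusers `multistep_uni_p_bh_update` as I understand it, assumes `predict_x0=True`, and uses hypothetical function names; `rks` holds one ratio per history term, most recent first.

```python
import numpy as np


def bh_system(rks, hh, order, solver_type="bh2"):
    """Build the UniPC-BH matrix R and right-hand side b for a given order."""
    # rks carries order - 1 history ratios; the trailing 1.0 is the new point.
    rks_full = np.append(np.asarray(rks, dtype=np.float64), 1.0)
    h_phi_1 = np.expm1(hh)
    h_phi_k = h_phi_1 / hh - 1.0
    b_h = hh if solver_type == "bh1" else np.expm1(hh)
    rows, rhs, factorial_i = [], [], 1
    for i in range(1, order + 1):
        rows.append(rks_full ** (i - 1))
        rhs.append(h_phi_k * factorial_i / b_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1.0 / factorial_i
    return np.stack(rows), np.array(rhs), float(h_phi_1), float(b_h)


def predictor_scales(rks, h, alpha_t, order, solver_type="bh2"):
    """Return (m0_scale, history_scales) for a predictor of arbitrary order."""
    rks = np.asarray(rks, dtype=np.float64)
    R, b, h_phi_1, b_h = bh_system(rks, -h, order, solver_type)
    if order == 1:
        rhos_p = np.zeros(0)
    elif order == 2:
        rhos_p = np.array([0.5])  # simplified value used for order 2
    else:
        rhos_p = np.linalg.solve(R[:-1, :-1], b[:-1])
    # Each history output m_k enters through (m_k - m0) / rk_k.
    history_scales = -alpha_t * b_h * rhos_p / rks
    m0_scale = -alpha_t * h_phi_1 - history_scales.sum()
    return float(m0_scale), history_scales
```

One property that holds by construction: the scales on `m0` and the history outputs sum to `-alpha_t * h_phi_1`, so the update collapses to first order when all model outputs are equal. That makes a cheap sanity check for any order.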

    m1_scale = float(-alpha_t * b_h * rhos_c[0] / rk)
    m0_scale = float(
        -alpha_t * h_phi_1 + alpha_t * b_h * (rhos_c[0] / rk + rhos_c[-1])
    )
    mt_scale = float(-alpha_t * b_h * rhos_c[-1])

The calculation of `m0_scale` and `m1_scale` in `_corrector_coefficients` only accounts for a single `rk` value (`rhos_c[0] / rk`). For order > 2, there are multiple history terms with different `rk` values, and each must be included in the summation. As written, higher solver orders will produce incorrect results.
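The summation over several `rk` values could be sketched as follows. The function assumes `predict_x0=True` and that `rhos_c`, `h_phi_1`, and `b_h` have already been computed elsewhere; the function name and the `history_scales` output are hypothetical, not part of this PR.

```python
import numpy as np


def corrector_scales(rhos_c, rks, alpha_t, h_phi_1, b_h):
    """Return (m0_scale, history_scales, mt_scale) for the corrector.

    rhos_c has one entry per history term plus a final entry for the
    model output at the new point; rks has one ratio per history term.
    """
    rhos_c = np.asarray(rhos_c, dtype=np.float64)
    rks = np.asarray(rks, dtype=np.float64)
    if rhos_c.shape[0] != rks.shape[0] + 1:
        raise ValueError("rhos_c must have exactly one more entry than rks")

    # Each history output m_k contributes rhos_c[k] * (m_k - m0) / rks[k].
    history_scales = -alpha_t * b_h * rhos_c[:-1] / rks
    # The new model output m_t contributes rhos_c[-1] * (m_t - m0).
    mt_scale = -alpha_t * b_h * rhos_c[-1]
    # m0 collects the first-order term minus everything subtracted above.
    m0_scale = -alpha_t * h_phi_1 - history_scales.sum() - mt_scale
    return float(m0_scale), history_scales, float(mt_scale)
```

With a single history term this reproduces the three expressions quoted above; with more terms, `history_scales` simply grows by one entry per additional `rk`.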